# TUP CATE ANALYSIS - COMPLETE SAMPLE SIZE ANALYSIS WITH ALL GROUPING METHODS
# Updated to include all grouping methods from previous TUP code

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_predict
import warnings
import sys
from datetime import datetime
# Silence every warning for the whole process (this includes library warnings
# from pandas / scikit-learn, e.g. convergence or deprecation notices).
# NOTE(review): a blanket 'ignore' can hide genuine problems; consider
# narrowing it to specific categories/modules.
warnings.filterwarnings('ignore')

def theoretical_bound(m, beta, N, delta, OPT):
    """Evaluate the sample-size lower bound on allocation value.

    Computes ``(1 - (N * ln(2N / delta) / m) ** beta) * OPT``.

    Args:
        m: Number of samples.
        beta: Exponent applied to the error ratio.
        N: Number of groups.
        delta: Failure probability.
        OPT: Optimal allocation value the bound is scaled by.

    Returns:
        The bound, or ``0`` once the ratio ``N * ln(2N / delta) / m`` reaches
        1 (the bound is vacuous there).
    """
    ratio = N * np.log(2 * N / delta) / m
    return 0 if ratio >= 1 else (1 - ratio ** beta) * OPT

class CompleteTUPSampleSizeAnalyzer:
    """Complete TUP CATE allocation with all grouping methods.

    Workflow:

    1. ``process_tup_data`` prepares the ``outcome`` and
       ``baseline_consumption`` columns.
    2. One of the ``create_*_groups`` methods partitions households into
       groups. Each kept group has enough treated and control members and
       carries a difference-in-means CATE.
    3. ``normalize_cates`` rescales those CATEs to [0, 1].
    4. ``analyze_sample_size_performance`` and ``plot_sample_size_analysis``
       measure how well a top-K allocation, chosen from simulated noisy
       samples, recovers the optimal allocation value.
    """

    def __init__(self, random_seed=42):
        """Store ``random_seed`` and seed NumPy's global RNG with it.

        The seed is also passed as ``random_state`` to the scikit-learn
        estimators, and is offset per trial in ``simulate_sampling_trial``.
        """
        self.random_seed = random_seed
        np.random.seed(random_seed)
        print(f"Complete TUP Sample Size Analyzer initialized with seed {random_seed}")

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _unique_in_order(items):
        """Deduplicate ``items`` while keeping first-seen order.

        ``list(set(...))`` orders strings by hash, and string hashing is
        randomised per process (PYTHONHASHSEED). Any later ``[0]`` or ``[:k]``
        pick would then differ between runs; this helper keeps pattern
        priority order instead.
        """
        return list(dict.fromkeys(items))

    @staticmethod
    def _quantile_bins(values, n_groups):
        """Assign each value to a quantile bin in ``0 .. n_groups - 1``.

        ``np.digitize`` against edges that include the maximum returns
        ``n_groups`` for the maximum itself. Clipping folds those rows into
        the top bin instead of silently discarding them.
        """
        values = np.asarray(values, dtype=float)
        edges = np.quantile(values, np.linspace(0, 1, n_groups + 1))
        bins = np.digitize(values, edges) - 1
        return np.clip(bins, 0, n_groups - 1)

    @staticmethod
    def _encode_features(X):
        """Return a numeric copy of feature frame ``X`` for scikit-learn.

        - Object/category columns are label-encoded on their string form;
          a constant 0 column is used if encoding fails.
        - Booleans become ints.
        - Missing numeric values are filled with the column median (0 if the
          whole column is missing); any other gaps are filled with 0.
        """
        X = X.copy()
        for col in X.columns:
            dtype = X[col].dtype
            if dtype == 'object' or dtype.name == 'category':
                try:
                    X[col] = LabelEncoder().fit_transform(X[col].astype(str))
                except (TypeError, ValueError):
                    X[col] = 0
            elif dtype == 'bool':
                X[col] = X[col].astype(int)

        for col in X.columns:
            if not X[col].isna().any():
                continue
            if pd.api.types.is_numeric_dtype(X[col]):
                fill_value = X[col].median()
                # An all-NaN column has a NaN median; fall back to 0.
                X[col] = X[col].fillna(0 if pd.isna(fill_value) else fill_value)
            else:
                X[col] = X[col].fillna(0)
        return X

    @staticmethod
    def _collect_groups(df, labels, n_labels, id_prefix, group_type, min_size):
        """Turn integer ``labels`` (positionally aligned with ``df``) into group dicts.

        Each label in ``0 .. n_labels - 1`` with at least ``min_size`` rows
        becomes ``{'id': f'{id_prefix}{label}', 'indices': [...],
        'type': group_type}``. ``indices`` holds ``df`` index labels.
        """
        labels = np.asarray(labels)
        groups = []
        for label in range(n_labels):
            indices = df.index[labels == label].tolist()
            if len(indices) >= min_size:
                groups.append({
                    'id': f'{id_prefix}{label}',
                    'indices': indices,
                    'type': group_type,
                })
        return groups

    @staticmethod
    def _attach_baseline(df_processed, candidates, verbose):
        """Set ``baseline_consumption`` from the first candidate column, else to 0."""
        if candidates:
            df_processed['baseline_consumption'] = df_processed[candidates[0]]
            if verbose:
                print(f"Using {candidates[0]} as baseline consumption")
        else:
            if verbose:
                print("Warning: No baseline consumption found, but proceeding with existing outcome")
            df_processed['baseline_consumption'] = 0

    # ------------------------------------------------------------------
    # Data preparation
    # ------------------------------------------------------------------

    def process_tup_data(self, df, outcome_col=None):
        """Return a cleaned copy of ``df`` with ``outcome`` and ``baseline_consumption``.

        The outcome is taken from the first available source:

        1. an existing ``outcome`` column;
        2. ``target_column_consumption``;
        3. ``outcome_col``, if given and present.

        ``baseline_consumption`` is copied from the first matching baseline
        column, or set to 0 if none matches:

        - sources 1-2 accept ``pc_exp_month_bl`` or a name containing both
          ``bl`` and ``exp``;
        - source 3 accepts a name containing ``bl`` plus ``exp``,
          ``consumption`` or ``income``.

        Rows missing ``outcome`` or ``treatment`` are dropped. ``df`` itself
        is not modified.

        Raises:
            ValueError: If ``treatment`` is missing or no outcome source exists.
        """
        print(f"Processing TUP data with {len(df)} observations")
        print(f"Available columns: {len(df.columns)} columns")

        # Diagnostic listing of columns that look like targets/outcomes.
        target_cols = [col for col in df.columns if 'target' in col.lower()]
        consumption_cols = [col for col in df.columns if 'consumption' in col.lower()]
        outcome_cols = [col for col in df.columns if 'outcome' in col.lower()]

        print(f"Columns with 'target': {target_cols}")
        print(f"Columns with 'consumption': {consumption_cols}")
        print(f"Columns with 'outcome': {outcome_cols}")

        df_processed = df.copy()

        if 'treatment' not in df_processed.columns:
            raise ValueError("Missing required 'treatment' column")

        exp_baseline_cols = [col for col in df.columns
                             if 'pc_exp_month_bl' in col or ('bl' in col and 'exp' in col)]

        if 'outcome' in df_processed.columns:
            print("Found existing 'outcome' column - using directly")
            self._attach_baseline(df_processed, exp_baseline_cols, verbose=True)
        elif 'target_column_consumption' in df_processed.columns:
            print("Found existing target_column_consumption - using as outcome")
            df_processed['outcome'] = df_processed['target_column_consumption']
            self._attach_baseline(df_processed, exp_baseline_cols, verbose=True)
        elif outcome_col and outcome_col in df_processed.columns:
            print(f"Using provided outcome column: {outcome_col}")
            df_processed['outcome'] = df_processed[outcome_col]
            broad_baseline_cols = [col for col in df.columns
                                   if 'bl' in col and any(x in col.lower() for x in ['exp', 'consumption', 'income'])]
            self._attach_baseline(df_processed, broad_baseline_cols, verbose=False)
        else:
            # Without an outcome, the dropna below would fail with an opaque KeyError.
            raise ValueError(
                "No outcome found: expected an 'outcome' or 'target_column_consumption' "
                f"column, or an existing outcome_col (got {outcome_col!r})"
            )

        # Final cleaning
        initial_size = len(df_processed)
        df_processed = df_processed.dropna(subset=['outcome', 'treatment'])
        final_size = len(df_processed)

        if initial_size != final_size:
            print(f"Dropped {initial_size - final_size} additional rows due to missing outcome/treatment")

        print(f"Final dataset: {final_size} households")
        print(f"Treatment distribution: {df_processed['treatment'].value_counts().to_dict()}")
        print(f"Outcome statistics: mean={df_processed['outcome'].mean():.3f}, std={df_processed['outcome'].std():.3f}")

        if 'baseline_consumption' in df_processed.columns:
            print(f"Baseline consumption stats: mean={df_processed['baseline_consumption'].mean():.3f}, std={df_processed['baseline_consumption'].std():.3f}")

        return df_processed

    # ------------------------------------------------------------------
    # Grouping methods
    # ------------------------------------------------------------------

    def create_baseline_poverty_groups(self, df, n_groups=30, min_size=6):
        """Group households into ``n_groups`` quantile bins of baseline consumption.

        If ``baseline_consumption`` is missing, it is added to ``df`` in place
        from the first column whose name contains ``pc_exp`` and ``bl``.
        Missing consumption values are median-filled before binning.

        Returns:
            Balanced groups with CATEs (see ``_ensure_balance_and_compute_cate``),
            or ``[]`` if no consumption measure is available.
        """
        print(f"Creating baseline poverty groups (target: {n_groups})")

        if 'baseline_consumption' not in df.columns:
            print("No baseline consumption found, using first available consumption measure")
            consumption_cols = [col for col in df.columns if 'pc_exp' in col and 'bl' in col]
            if not consumption_cols:
                print("No baseline consumption measures found")
                return []
            df['baseline_consumption'] = df[consumption_cols[0]]

        if len(df) == 0:
            return []

        consumption = df['baseline_consumption'].fillna(df['baseline_consumption'].median())
        bins = self._quantile_bins(consumption, n_groups)
        groups = self._collect_groups(df, bins, n_groups, 'poverty_level_', 'baseline_poverty', min_size)

        print(f"Created {len(groups)} baseline poverty groups")
        return self._ensure_balance_and_compute_cate(df, groups)

    def create_demographics_groups(self, df, min_size=6):
        """Group households by combinations of up to 3 demographic variables.

        Candidate columns match a demographic name pattern, do not start with
        ``gkt``, and have 2-10 distinct values. If none qualify, a
        ``consumption_above_median`` indicator is added to ``df`` in place
        and used instead.
        """
        print(f"Creating TUP demographics groups")

        demo_patterns = ['female', 'gender', 'head', 'age', 'education', 'literate',
                         'married', 'widow', 'household_size', 'children', 'caste', 'religion']

        # Order-preserving dedup so the "first 3 features" pick is reproducible.
        potential_features = self._unique_in_order(
            col
            for pattern in demo_patterns
            for col in df.columns
            if pattern in col.lower() and not col.startswith('gkt')
        )

        available_features = [
            col for col in potential_features
            if col in df.columns and df[col].notna().sum() > 0 and 2 <= df[col].nunique() <= 10
        ]

        if not available_features:
            print("No suitable demographic variables found, creating simple binary splits")
            if 'baseline_consumption' not in df.columns:
                return []
            median_consumption = df['baseline_consumption'].median()
            df['consumption_above_median'] = (df['baseline_consumption'] > median_consumption).astype(int)
            available_features = ['consumption_above_median']

        print(f"Using demographic features: {available_features}")
        available_features = available_features[:3]

        df_clean = df.dropna(subset=available_features)
        print(f"After removing missing values: {len(df_clean)}/{len(df)} households")

        if len(df_clean) == 0:
            return []

        groups = []
        unique_combinations = df_clean[available_features].drop_duplicates()
        print(f"Found {len(unique_combinations)} unique demographic combinations")

        for _, combo in unique_combinations.iterrows():
            mask = pd.Series(True, index=df.index)
            combo_description = []
            for feature in available_features:
                mask &= (df[feature] == combo[feature])
                combo_description.append(f"{feature}={combo[feature]}")

            indices = df[mask].index.tolist()
            if len(indices) >= min_size:
                groups.append({
                    'id': "_".join(combo_description),
                    'indices': indices,
                    'type': 'demographics'
                })

        print(f"Created {len(groups)} demographic groups")
        return self._ensure_balance_and_compute_cate(df, groups)

    def create_village_groups(self, df, min_size=6):
        """Group households by the first baseline location indicator found.

        Candidate columns contain a location pattern (checked in the listed
        priority order) and ``bl_``, and exclude food/loan-related names.
        Each distinct non-missing value of the chosen column becomes a group.
        Returns ``[]`` when no location column exists.
        """
        print(f"Creating village-based groups from survey indicators (min_size={min_size})")

        location_patterns = ['district', 'rural', 'urban', 'capital', 'area', 'upazila', 'thana']
        excluded_terms = ['gram flour', 'food', 'loan']
        location_cols = self._unique_in_order(
            col
            for pattern in location_patterns
            for col in df.columns
            if pattern in col.lower() and 'bl_' in col
            and not any(term in col.lower() for term in excluded_terms)
        )

        # Initialised up front: previously unset (UnboundLocalError) when no
        # location column existed.
        groups = []
        if location_cols:
            print(f"Found location indicator columns: {location_cols[:3]}...")
            location_col = location_cols[0]
            print(f"Using location indicator: {location_col}")

            for location_value in df[location_col].unique():
                if pd.isna(location_value):
                    continue
                indices = df[df[location_col] == location_value].index.tolist()
                if len(indices) >= min_size:
                    groups.append({
                        'id': f'location_{location_col}_{location_value}',
                        'indices': indices,
                        'type': 'village'
                    })
        else:
            print("No location indicators found")

        print(f"Raw geographic groups created: {len(groups)}")
        balanced_groups = self._ensure_balance_and_compute_cate(df, groups)
        print(f"Balanced geographic groups after filtering: {len(balanced_groups)}")
        return balanced_groups

    def create_causal_forest_groups(self, df, n_groups=30, min_size=6):
        """Cluster households on covariates plus a T-learner CATE prediction.

        Two random forests are fit, one on treated rows and one on control
        rows, each predicting ``outcome``. Their prediction difference is
        appended to the standardised covariates, and the result is clustered
        with KMeans.

        Returns ``[]`` when either arm has fewer than 5 rows.
        """
        print(f"Creating causal forest groups (target: {n_groups})")

        # 'target_column' is excluded because target_column_consumption is a
        # copy of the outcome; leaving it in leaks the target into both forests.
        exclude_patterns = ['outcome', 'treatment', 'target_column',
                            'el1', 'el2', 'el3', 'el4', 'total_score']
        feature_cols = [col for col in df.columns
                        if not any(pattern in col for pattern in exclude_patterns)]

        X = self._encode_features(df[feature_cols])

        treated_mask = df['treatment'] == 1
        control_mask = df['treatment'] == 0

        if treated_mask.sum() < 5 or control_mask.sum() < 5:
            print("Not enough treated or control observations for causal forest")
            return []

        rf_treated = RandomForestRegressor(n_estimators=100, random_state=self.random_seed)
        rf_control = RandomForestRegressor(n_estimators=100, random_state=self.random_seed)

        rf_treated.fit(X[treated_mask], df.loc[treated_mask, 'outcome'])
        rf_control.fit(X[control_mask], df.loc[control_mask, 'outcome'])

        pred_cate = rf_treated.predict(X) - rf_control.predict(X)
        cluster_features = np.column_stack([X.values, pred_cate.reshape(-1, 1)])
        cluster_features = StandardScaler().fit_transform(cluster_features)

        # KMeans cannot form more clusters than there are rows.
        n_clusters = min(n_groups, len(X))
        labels = KMeans(n_clusters=n_clusters, random_state=self.random_seed).fit_predict(cluster_features)

        groups = self._collect_groups(df, labels, n_groups, 'causal_forest_', 'causal_forest', min_size)

        print(f"Created {len(groups)} causal forest groups")
        return self._ensure_balance_and_compute_cate(df, groups)

    def create_propensity_groups(self, df, n_groups=50, min_size=6):
        """Create groups based on propensity score strata.

        Propensity scores are 5-fold out-of-fold predictions from a logistic
        regression of ``treatment`` on all non-outcome columns. The scores are
        then split into ``n_groups`` quantile strata.

        Returns ``[]`` when either arm has fewer than 5 rows (too few for
        5-fold cross-validation).
        """
        print(f"Creating propensity score groups (target: {n_groups})")

        n_treated = int((df['treatment'] == 1).sum())
        n_control = int((df['treatment'] == 0).sum())
        if min(n_treated, n_control) < 5:
            print("Not enough treated or control observations for propensity model")
            return []

        feature_cols = [col for col in df.columns
                        if col not in ['treatment', 'outcome', 'target_column_consumption']]
        X = self._encode_features(df[feature_cols])

        # max_iter raised from sklearn's default of 100: with many unscaled
        # features lbfgs rarely converges that fast, and the global warning
        # filter would hide the ConvergenceWarning.
        prop_scores = cross_val_predict(
            LogisticRegression(random_state=self.random_seed, max_iter=1000),
            X, df['treatment'], method='predict_proba', cv=5
        )[:, 1]

        bins = self._quantile_bins(prop_scores, n_groups)
        groups = self._collect_groups(df, bins, n_groups, 'propensity_', 'propensity', min_size)

        print(f"Created {len(groups)} propensity groups")
        return self._ensure_balance_and_compute_cate(df, groups)

    def create_asset_groups(self, df, n_groups=30, min_size=6):
        """Cluster households on up to 10 baseline asset-related columns.

        Falls back to ``create_baseline_poverty_groups`` when no column name
        matches an asset pattern together with ``bl``.
        """
        print(f"Creating asset groups (target: {n_groups})")

        asset_patterns = ['asset', 'own', 'land', 'livestock', 'house', 'roof', 'wall', 'floor', 'toilet']
        asset_cols = self._unique_in_order(
            col
            for pattern in asset_patterns
            for col in df.columns
            if pattern in col.lower() and 'bl' in col.lower()
        )

        if not asset_cols:
            print("No asset columns found, using baseline consumption as proxy")
            return self.create_baseline_poverty_groups(df, n_groups, min_size)

        print(f"Found {len(asset_cols)} asset-related columns")
        asset_cols = asset_cols[:10]

        X_scaled = StandardScaler().fit_transform(self._encode_features(df[asset_cols]))
        n_clusters = min(n_groups, len(X_scaled))
        labels = KMeans(n_clusters=n_clusters, random_state=self.random_seed).fit_predict(X_scaled)

        groups = self._collect_groups(df, labels, n_groups, 'asset_cluster_', 'asset', min_size)

        print(f"Created {len(groups)} asset groups")
        return self._ensure_balance_and_compute_cate(df, groups)

    def create_household_composition_groups(self, df, min_size=6):
        """Create groups from household size terciles and number of children.

        Size groups come from three equal-width bins of the first ``*size*``
        column. If fewer than 5 groups exist, two children-based groups are
        added: no children, and 2 or more children. The groups may overlap.
        """
        print(f"Creating household composition groups")

        hh_patterns = ['household_size', 'children', 'adult', 'elderly', 'dependent', 'member']
        hh_cols = self._unique_in_order(
            col
            for pattern in hh_patterns
            for col in df.columns
            if pattern in col.lower() and 'bl' in col.lower()
        )

        if not hh_cols:
            print("No household composition columns found")
            return []

        print(f"Found household composition columns: {hh_cols}")

        groups = []

        size_cols = [col for col in hh_cols if 'size' in col.lower()]
        if size_cols:
            size_col = size_cols[0]
            size_values = df[size_col].fillna(df[size_col].median())
            size_categories = pd.cut(size_values, bins=3, labels=['Small', 'Medium', 'Large'])

            for category in size_categories.cat.categories:
                indices = df.index[size_categories == category].tolist()
                if len(indices) >= min_size:
                    groups.append({
                        'id': f'household_size_{category}',
                        'indices': indices,
                        'type': 'household_composition'
                    })

        children_cols = [col for col in hh_cols if 'children' in col.lower()]
        if children_cols and len(groups) < 5:
            children_values = df[children_cols[0]].fillna(0)
            children_splits = (
                ('no_children', children_values == 0),
                ('children_gte_2', children_values >= 2),
            )
            for label, mask in children_splits:
                indices = df.index[mask].tolist()
                if len(indices) >= min_size:
                    groups.append({
                        'id': f'household_{label}',
                        'indices': indices,
                        'type': 'household_composition'
                    })

        print(f"Created {len(groups)} household composition groups")
        return self._ensure_balance_and_compute_cate(df, groups)

    def _ensure_balance_and_compute_cate(self, df, groups):
        """Keep balanced groups and attach a difference-in-means CATE.

        A group is kept when its treatment rate is in [0.15, 0.85] and it
        has at least 3 treated and 3 control rows. Its CATE is
        ``mean(outcome | treated) - mean(outcome | control)``.
        """
        balanced_groups = []

        for group in groups:
            group_df = df.loc[group['indices']]

            treatment_rate = group_df['treatment'].mean()
            n_treated = group_df['treatment'].sum()
            n_control = len(group_df) - n_treated

            if not (0.15 <= treatment_rate <= 0.85 and n_treated >= 3 and n_control >= 3):
                continue

            treated_outcomes = group_df.loc[group_df['treatment'] == 1, 'outcome']
            control_outcomes = group_df.loc[group_df['treatment'] == 0, 'outcome']
            cate = treated_outcomes.mean() - control_outcomes.mean()

            balanced_groups.append({
                'id': group['id'],
                'indices': group['indices'],
                'size': len(group_df),
                'treatment_rate': treatment_rate,
                'n_treated': int(n_treated),
                'n_control': int(n_control),
                'cate': cate,
                'type': group['type']
            })

        return balanced_groups

    # ------------------------------------------------------------------
    # Simulation and plotting
    # ------------------------------------------------------------------

    def normalize_cates(self, groups):
        """Min-max scale each group's ``cate`` into ``normalized_cate`` in [0, 1].

        All groups get 0.5 when every CATE is equal. Mutates and returns
        ``groups``; an empty list is returned unchanged.
        """
        if not groups:
            print("No groups to normalise")
            return groups

        cates = [g['cate'] for g in groups]
        min_cate, max_cate = min(cates), max(cates)
        spread = max_cate - min_cate

        for group in groups:
            group['normalized_cate'] = (group['cate'] - min_cate) / spread if spread > 0 else 0.5

        print(f"CATE normalization: [{min_cate:.3f}, {max_cate:.3f}] → [0, 1]")
        return groups

    def simulate_sampling_trial(self, groups, sample_size, trial_seed):
        """Simulate ``sample_size`` noisy observations of the groups' CATEs.

        Each sample picks a group uniformly at random and draws
        Bernoulli(``normalized_cate``). NumPy's global RNG is reseeded with
        ``random_seed + trial_seed``.

        Returns:
            ``(tau_estimates, sample_counts)``: per-group empirical means
            (0 for groups never sampled) and per-group sample counts.
        """
        np.random.seed(self.random_seed + trial_seed)

        n_groups = len(groups)
        tau_true = np.array([g['normalized_cate'] for g in groups])

        successes = np.zeros(n_groups)
        sample_counts = np.zeros(n_groups)

        # Keep the per-sample randint/binomial interleaving so the random
        # stream (and thus results for a given seed) is stable.
        for _ in range(sample_size):
            group_idx = np.random.randint(n_groups)
            successes[group_idx] += np.random.binomial(1, tau_true[group_idx])
            sample_counts[group_idx] += 1

        tau_estimates = np.divide(successes, sample_counts,
                                  out=np.zeros(n_groups), where=sample_counts > 0)
        return tau_estimates, sample_counts

    def analyze_sample_size_performance(self, groups, sample_sizes, budget_percentages, n_trials=50):
        """Measure top-K allocation value versus sample size.

        For each budget fraction ``p``, ``K = max(1, int(p * n_groups))``. For
        each sample size, ``n_trials`` simulated trials pick the K groups with
        the highest estimated CATE and score them by their true normalised
        CATE.

        Returns:
            ``(results, optimal_values)``.

            - ``results[p]`` holds parallel lists ``sample_sizes``,
              ``values`` (mean over trials) and ``stds``.
            - ``optimal_values[p]`` is the sum of the K largest true
              normalised CATEs.
        """
        print(f"Analyzing sample size performance with {len(groups)} groups")

        n_groups = len(groups)
        tau_true = np.array([g['normalized_cate'] for g in groups])

        budgets = [max(1, int(p * n_groups)) for p in budget_percentages]
        print(f"Budget percentages {budget_percentages} → K values {budgets}")

        true_order = np.argsort(tau_true)
        optimal_values = {bp: np.sum(tau_true[true_order[-K:]])
                          for bp, K in zip(budget_percentages, budgets)}

        results = {bp: {'sample_sizes': [], 'values': [], 'stds': []} for bp in budget_percentages}

        for sample_size in sample_sizes:
            print(f"  Sample size {sample_size}...")

            budget_trial_values = {bp: [] for bp in budget_percentages}

            for trial in range(n_trials):
                tau_estimates, _ = self.simulate_sampling_trial(groups, sample_size, trial)
                # One sort per trial serves every budget.
                estimated_order = np.argsort(tau_estimates)

                for bp, K in zip(budget_percentages, budgets):
                    selected_indices = estimated_order[-K:]
                    budget_trial_values[bp].append(np.sum(tau_true[selected_indices]))

            for bp in budget_percentages:
                results[bp]['sample_sizes'].append(sample_size)
                results[bp]['values'].append(np.mean(budget_trial_values[bp]))
                results[bp]['stds'].append(np.std(budget_trial_values[bp]))

        return results, optimal_values

    def plot_sample_size_analysis(self, results, optimal_values, method_name, budget_percentages, n_groups):
        """Plot normalised allocation value vs. sample size, one panel per budget.

        Each panel shows three things:

        - the empirical mean ± std, divided by the optimal value;
        - the optimum, drawn at 1.0;
        - ``theoretical_bound`` curves for beta = 0.5 and 1.0 with
          delta = 0.05.

        Panels are laid out three per row (6 budgets give the original 2x3
        grid). The figure is saved as a PDF, shown, then closed.
        """
        n_panels = len(budget_percentages)
        n_cols = 3
        n_rows = max(1, -(-n_panels // n_cols))  # ceiling division
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(6 * n_cols, 6 * n_rows), squeeze=False)
        axes = axes.ravel()

        delta = 0.05

        print(f"\nPlotting {method_name} (N={n_groups})")
        print("="*60)

        for ax, bp in zip(axes, budget_percentages):
            sample_sizes = results[bp]['sample_sizes']
            optimal_val = optimal_values[bp]

            values_norm = np.array(results[bp]['values']) / optimal_val
            stds_norm = np.array(results[bp]['stds']) / optimal_val

            ax.errorbar(sample_sizes, values_norm, yerr=stds_norm,
                        marker='o', capsize=5, capthick=3, linewidth=6, markersize=8,
                        label='Empirical data', color='blue', alpha=0.8)

            ax.axhline(y=1.0, color='black', linestyle=':', linewidth=2,
                       label='Optimal (1.0)', alpha=0.8)

            m_smooth = np.linspace(min(sample_sizes), max(sample_sizes), 200)
            ref_curve_05 = [theoretical_bound(m, 0.5, n_groups, delta, optimal_val) / optimal_val for m in m_smooth]
            ref_curve_10 = [theoretical_bound(m, 1.0, n_groups, delta, optimal_val) / optimal_val for m in m_smooth]

            ax.plot(m_smooth, ref_curve_05, 'red', linestyle=(0, (3, 2)), linewidth=6,
                    label='FullCATE', alpha=0.8)
            ax.plot(m_smooth, ref_curve_10, 'green', linestyle=(0, (3, 1, 1, 1)), linewidth=6,
                    label='ALLOC', alpha=0.8)

            ax.set_xlabel('Sample size', fontsize=23)
            ax.set_ylabel('Normalized allocation value', fontsize=23)
            ax.set_title(f'Budget = {bp*100:.0f}% (K={max(1, int(bp * n_groups))})', fontsize=24, fontweight='bold')

            ax.legend(fontsize=21, framealpha=0.9)
            ax.grid(True, alpha=0.4, linewidth=1)
            ax.tick_params(axis='both', which='major', labelsize=16, width=1.5, length=5)
            ax.set_ylim(0.2, 1.05)  # slightly above the optimum

            for spine in ax.spines.values():
                spine.set_linewidth(1.5)

        # Hide grid cells left over when the budget count is not a multiple of 3.
        for ax in axes[n_panels:]:
            ax.set_visible(False)

        plt.suptitle(f'{method_name} (N={n_groups})', fontsize=24, fontweight='bold')
        plt.tight_layout()

        clean_name = method_name.replace(' ', '_').replace('(', '').replace(')', '').replace('-', '_')
        pdf_filename = f"{clean_name}_N{n_groups}_sample_size_analysis.pdf"
        plt.savefig(pdf_filename, format='pdf', dpi=300, bbox_inches='tight')
        print(f"Saved plot as: {pdf_filename}")

        plt.show()
        # Release the figure; one is created per method and they add up.
        plt.close(fig)

        print(f"Plot complete for {method_name}")


def run_complete_tup_sample_size_analysis(df_tup, sample_size_range=None, budget_percentages=None, n_trials=50, outcome_col=None, min_groups=10):
    """Run the sample-size analysis on the TUP dataset for every grouping method.

    For each method this:

    1. builds groups on a fresh copy of the processed data (with a freshly
       seeded analyzer);
    2. normalises the group CATEs;
    3. simulates allocation value against sample size;
    4. saves and shows a plot.

    A method that fails or yields fewer than ``min_groups`` groups is
    reported and skipped.

    Args:
        df_tup: Dataframe as returned by ``preprocess_tup_data``.
        sample_size_range: Sample sizes to simulate (a default grid if None).
        budget_percentages: Budget fractions of the number of groups
            (a default list if None).
        n_trials: Simulated trials per sample size.
        outcome_col: Fallback outcome column passed to ``process_tup_data``.
        min_groups: Minimum number of balanced groups needed to analyse a
            method (default 10).

    Returns:
        ``{method_name: {'results', 'optimal_values', 'n_groups'}}`` for each
        method that completed.
    """
    import traceback

    if sample_size_range is None:
        sample_size_range = [100, 250, 500, 750, 1000, 1200, 1500, 2000, 5000, 10000, 20000]

    if budget_percentages is None:
        budget_percentages = [0.1, 0.2, 0.3, 0.5, 0.7, 0.9]

    print("COMPLETE TUP SAMPLE SIZE ANALYSIS - ALL GROUPING METHODS")
    print(f"Sample sizes: {sample_size_range}")
    print(f"Budget percentages: {budget_percentages}")
    print(f"Trials per sample size: {n_trials}")
    print("="*80)

    methods = [
        ('Village Groups', lambda analyzer, df: analyzer.create_village_groups(df, min_size=6)),
        ('Baseline Poverty (30)', lambda analyzer, df: analyzer.create_baseline_poverty_groups(df, n_groups=30, min_size=6)),
        ('Baseline Poverty (50)', lambda analyzer, df: analyzer.create_baseline_poverty_groups(df, n_groups=50, min_size=6)),
        ('Demographics', lambda analyzer, df: analyzer.create_demographics_groups(df, min_size=6)),
        ('Causal Forest (30)', lambda analyzer, df: analyzer.create_causal_forest_groups(df, n_groups=30, min_size=6)),
        ('Causal Forest (50)', lambda analyzer, df: analyzer.create_causal_forest_groups(df, n_groups=50, min_size=6)),
        ('Propensity Score (30)', lambda analyzer, df: analyzer.create_propensity_groups(df, n_groups=30, min_size=6)),
        ('Propensity Score (50)', lambda analyzer, df: analyzer.create_propensity_groups(df, n_groups=50, min_size=6)),
        ('Asset Groups (30)', lambda analyzer, df: analyzer.create_asset_groups(df, n_groups=30, min_size=6)),
        ('Asset Groups (50)', lambda analyzer, df: analyzer.create_asset_groups(df, n_groups=50, min_size=6)),
        ('Household Composition', lambda analyzer, df: analyzer.create_household_composition_groups(df, min_size=6))
    ]

    all_results = {}

    # Outcome/baseline processing does not depend on the grouping method, so
    # do it once rather than once per method.
    try:
        processed_base = CompleteTUPSampleSizeAnalyzer().process_tup_data(df_tup, outcome_col=outcome_col)
    except Exception as e:
        print(f"Error while processing TUP data: {e}")
        traceback.print_exc()
        return all_results

    for method_name, method_func in methods:
        print(f"\n{'='*80}")
        print(f"ANALYZING TUP METHOD: {method_name}")
        print("="*80)

        try:
            # Fresh analyzer reseeds the global RNG; fresh copy because some
            # grouping methods add helper columns to the frame they receive.
            analyzer = CompleteTUPSampleSizeAnalyzer()
            df_processed = processed_base.copy()

            groups = method_func(analyzer, df_processed)

            if len(groups) < min_groups:
                print(f"Too few groups ({len(groups)}) for {method_name} - skipping")
                continue

            groups = analyzer.normalize_cates(groups)

            results, optimal_values = analyzer.analyze_sample_size_performance(
                groups, sample_size_range, budget_percentages, n_trials
            )

            all_results[method_name] = {
                'results': results,
                'optimal_values': optimal_values,
                'n_groups': len(groups)
            }

            print(f"Creating plots for {method_name}...")
            analyzer.plot_sample_size_analysis(
                results, optimal_values, method_name, budget_percentages, len(groups)
            )

            print(f"\nSummary for {method_name}:")
            print(f"Number of groups: {len(groups)}")
            print("Optimal values by budget:")
            for bp in budget_percentages:
                print(f"  {bp*100:.0f}%: {optimal_values[bp]:.3f}")

        except Exception as e:
            # Best-effort per method: report with a traceback, then move on.
            print(f"Error with {method_name}: {e}")
            traceback.print_exc()
            continue

    return all_results

def preprocess_tup_data(filepath, n_top_features=1000, random_state=42):
    """Load the TUP household Stata file and build the modelling dataframe.

    Steps:
      1. Drop endline columns (names ending in ``el1``..``el4`` or starting
         with ``el``) and all object-dtype (string) columns.
      2. One-hot encode what remains with ``pd.get_dummies``.
      3. Build ``target_column_consumption = pc_exp_month_el3 - pc_exp_month_bl``
         from the raw data and keep only rows where it is observed.
      4. Drop all-empty columns and fill remaining numeric gaps with column means.
      5. Keep the ``n_top_features`` most important features of a random
         forest regressing the target. ``treatment`` and ``pc_exp_month_bl``
         are always kept.

    Args:
        filepath: Path to the ``.dta`` file.
        n_top_features: Number of features to keep after ranking (default 1000).
        random_state: Seed for the feature-ranking random forest (default 42).

    Returns:
        Dataframe of the selected features plus ``target_column_consumption``
        and an identical ``outcome`` column.
    """
    df1 = pd.read_stata(filepath)
    print(f"Loaded TUP data: {df1.shape}")

    endline_cols = {col for col in df1.columns
                    if col.endswith(('el1', 'el2', 'el3', 'el4')) or col.startswith('el')}
    # All object columns are excluded. The earlier "high-cardinality category"
    # filter only ever removed a subset of these, so it had no effect.
    string_cols = set(df1.select_dtypes(include=['object']).columns)
    cols_keep = [col for col in df1.columns
                 if col not in string_cols and col not in endline_cols]

    df1_encoded = pd.get_dummies(df1[cols_keep])

    # Target uses the raw frame: the el3 column was dropped from the features above.
    df1_encoded['target_column_consumption'] = df1['pc_exp_month_el3'] - df1['pc_exp_month_bl']

    df1_encoded = df1_encoded[df1_encoded['target_column_consumption'].notna()]
    df1_encoded = df1_encoded.dropna(axis=1, how='all')
    # numeric_only avoids a TypeError on pandas >= 2 if a non-numeric column survives.
    df1_encoded = df1_encoded.fillna(df1_encoded.mean(numeric_only=True))

    X = df1_encoded.drop(columns='target_column_consumption')
    y = df1_encoded['target_column_consumption']

    rf = RandomForestRegressor(n_estimators=100, random_state=random_state)
    rf.fit(X, y)

    importance_df = pd.DataFrame({'Feature': X.columns, 'Importance': rf.feature_importances_})
    top_features = (importance_df.sort_values(by='Importance', ascending=False)
                    .head(n_top_features)['Feature'].tolist())

    for required in ('treatment', 'pc_exp_month_bl'):
        if required not in top_features:
            top_features.append(required)

    df_final = df1_encoded[top_features].copy()
    df_final['target_column_consumption'] = df1_encoded['target_column_consumption']
    df_final['outcome'] = df1_encoded['target_column_consumption']

    return df_final


# Example usage
if __name__ == "__main__":
    # Load the raw TUP household file, then run every grouping method over the
    # sample-size x budget grid; each method saves its own PDF.
    tup_frame = preprocess_tup_data('TUP_HH_Constructed.dta')

    analysis_results = run_complete_tup_sample_size_analysis(
        tup_frame,
        sample_size_range=[100, 250, 500, 750, 1000, 1200, 1500, 2000, 5000, 10000, 20000],
        budget_percentages=[0.1, 0.2, 0.3, 0.5, 0.7, 0.9],
        n_trials=50,
    )

    banner = "=" * 80
    print("\n" + banner)
    print("COMPLETE TUP ANALYSIS WITH ALL GROUPING METHODS FINISHED")
    print(banner)